import io
import os
import platform
import re
import sqlite3
from collections import OrderedDict
from datetime import datetime
from random import choice
from subprocess import check_output
from time import perf_counter, time

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "settings")
import django
django.setup()

import matplotlib.pyplot as plt
import pandas as pd
import psycopg2
from django.conf import settings
from django.contrib.auth.models import Group, User
from django.core.cache import caches
from django.db import connection, connections
from django.test.utils import CaptureQueriesContext, override_settings
from django.utils.encoding import force_text
from MySQLdb import _mysql

import cachalot
from cachalot.api import invalidate
from cachalot.tests.models import Test


# Directory receiving every output of a run (conditions.rst, data.csv,
# *_results.rst and *.svg); one sub-directory per day, named after today's
# ISO date, e.g. "benchmark/docs/2024-01-31/".
RESULTS_PATH = f"benchmark/docs/{datetime.now().date()}/"
# Benchmark contexts, in plotting order: cachalot disabled ("Control"),
# cache invalidated before every run ("Cold cache"), and cache already warm so
# no SQL query is expected ("Hot cache").
CONTEXTS = ("Control", "Cold cache", "Hot cache")
# NOTE(review): not referenced anywhere in the visible code — possibly unused.
DIVIDER = "divider"

# Path whose backing disk model is reported on Linux. Presumably chosen
# because database servers usually keep their data under /var/lib — confirm.
LINUX_DATA_PATH = "/var/lib/"
# Parses one line of `lsblk -Po MODEL,MOUNTPOINT` (KEY="value" pairs) into
# (model, mount point).
DISK_DATA_RE = re.compile(r'^MODEL="(.*)" MOUNTPOINT="(.*)"$')


def get_disk_model_for_path_linux(path):
    """Return the model name of the disk holding *path*, or ``None``.

    Runs ``lsblk -Po MODEL,MOUNTPOINT`` and parses its ``KEY="value"`` lines.
    Partition lines have an empty MODEL, so each mount point is attributed to
    the most recent non-empty model listed above it. The longest mount point
    containing *path* wins, so "/var" beats "/" for "/var/lib/".

    :param path: absolute filesystem path to look up.
    :returns: the stripped disk model, or ``None`` if no mount point
        contains *path*.
    """
    out = force_text(check_output(["lsblk", "-Po", "MODEL,MOUNTPOINT"]))
    mount_points = []
    previous_model = None
    for line in out.splitlines():
        match = DISK_DATA_RE.match(line)
        if match is None:
            # Blank or unexpected line: skip it instead of crashing on .groups().
            continue
        model, mount_point = match.groups()
        if model:
            previous_model = model.strip()
        if mount_point:
            mount_points.append((previous_model, mount_point))
    # Most specific (longest) mount point first; the sort is stable.
    mount_points.sort(key=lambda pair: len(pair[1]), reverse=True)
    for model, mount_point in mount_points:
        if _path_is_under(path, mount_point):
            return model
    return None


def _path_is_under(path, mount_point):
    """Return True if *path* is *mount_point* itself or lies beneath it.

    Compares whole path components, so "/va" does not contain "/var/lib".
    """
    prefix = mount_point.rstrip("/")
    return path == prefix or path.startswith(prefix + "/")


def write_conditions():
    """Write ``conditions.rst`` describing the machine and software versions.

    Hardware (CPU, RAM, disk model, OS) is read from ``/proc`` and ``lsblk``
    on Linux. On any other system the macOS tools ``sysctl`` and ``diskutil``
    are used. Then the versions of Python, Django, cachalot, SQLite,
    PostgreSQL, MySQL, Redis, memcached, psycopg2 and mysqlclient are
    collected. Everything is written to ``RESULTS_PATH/conditions.rst`` as a
    two-column reStructuredText simple table. The table is preceded by a
    sentence giving the number of runs per test (``Benchmark.n``).

    Requires ``redis-cli`` and ``memcached`` on ``PATH``, and the
    "postgresql" and "mysql" database aliases to be reachable.

    :raises django.db.utils.OperationalError: if the "postgresql" or "mysql"
        connection cannot be opened. The message explains how to create the
        missing "cachalot" database.
    """
    versions = OrderedDict()
    distribution = platform.uname()

    if distribution.system == "Linux":
        with open("/proc/cpuinfo") as f:
            versions["CPU"] = re.search(
                r"^model name\s+: (.+)$", f.read(), flags=re.MULTILINE
            ).group(1)
        with open("/proc/meminfo") as f:
            versions["RAM"] = re.search(
                r"^MemTotal:\s+(.+)$", f.read(), flags=re.MULTILINE
            ).group(1)
        # Model of the disk holding LINUX_DATA_PATH; may be None if not found.
        versions["Disk"] = get_disk_model_for_path_linux(LINUX_DATA_PATH)
        # NOTE: platform.uname() gives the kernel name and release here,
        # not the distribution name.
        versions["Linux distribution"] = f"{distribution.system} {distribution.release}"
    else:
        # Assumed to be macOS (Darwin): these commands do not exist elsewhere.
        versions["CPU"] = os.popen("sysctl -n machdep.cpu.brand_string").read().rstrip("\n")
        versions["RAM"] = os.popen("sysctl -n hw.memsize").read().rstrip("\n")
        disk_line = os.popen(
            "diskutil info /dev/disk0 | grep 'Device / Media Name'"
        ).read()
        # partition() tolerates a missing line (no IndexError) and keeps any
        # colon inside the model name. Empty output is recorded as None.
        versions["Disk"] = disk_line.partition(":")[2].strip() or None
        versions["OS"] = f"{distribution.system} {distribution.release}"

    versions["Python"] = platform.python_version()
    versions["Django"] = django.__version__
    versions["cachalot"] = cachalot.__version__
    versions["sqlite"] = sqlite3.sqlite_version

    # PostgreSQL
    try:
        with connections["postgresql"].cursor() as cursor:
            cursor.execute("SELECT version();")
            versions["PostgreSQL"] = re.match(
                r"^PostgreSQL\s+(\S+)\s", cursor.fetchone()[0]
            ).group(1)
    except django.db.utils.OperationalError as err:
        raise django.db.utils.OperationalError(
            "You need a PostgreSQL DB called \"cachalot\" first. "
            "Login with \"psql -U postgres -h localhost\" and run: "
            "CREATE DATABASE cachalot;"
        ) from err

    # MySQL
    try:
        with connections["mysql"].cursor() as cursor:
            cursor.execute("SELECT version();")
            # Keep only the numeric part before any "-suffix".
            versions["MySQL"] = cursor.fetchone()[0].split("-")[0]
    except django.db.utils.OperationalError as err:
        raise django.db.utils.OperationalError(
            "You need a MySQL DB called \"cachalot\" first. "
            "Login with \"mysql -u root\" and run: CREATE DATABASE cachalot;"
        ) from err

    # Redis: INFO output uses CRLF line endings; drop "\r" so "$" matches.
    out = force_text(check_output(["redis-cli", "INFO", "server"])).replace("\r", "")
    versions["Redis"] = re.search(
        r"^redis_version:([\d\.]+)$", out, flags=re.MULTILINE
    ).group(1)

    # memcached: the first line of `memcached -h` is "memcached <version>".
    out = force_text(check_output(["memcached", "-h"]))
    versions["memcached"] = re.match(
        r"^memcached ([\d\.]+)$", out, flags=re.MULTILINE
    ).group(1)

    versions["psycopg2"] = psycopg2.__version__.split()[0]
    versions["mysqlclient"] = _mysql.__version__

    with io.open(
        os.path.join(RESULTS_PATH, "conditions.rst"), "w", encoding="utf-8"
    ) as f:
        f.write(
            "In this benchmark, a small database is generated, "
            "and each test is executed %s times "
            "under the following conditions:\n\n" % Benchmark.n
        )

        def write_table_sep(char="="):
            # Border of an rst simple table: 20-char key column, value column.
            f.write((char * 20) + " " + (char * 50) + "\n")

        write_table_sep()
        for k, v in versions.items():
            # Some values may be missing (e.g. disk model not found); write a
            # placeholder instead of failing on str + None.
            value = "unknown" if v is None else str(v)
            f.write(k.ljust(20) + " " + value + "\n")
        write_table_sep()


class AssertNumQueries(CaptureQueriesContext):
    """Capture the queries run on one connection and warn on a count mismatch.

    A mismatch is only printed, never raised, so a benchmark run keeps going.

    :param n: expected number of queries inside the ``with`` block.
    :param using: database alias; ``None`` means the default ``connection``.
    """

    def __init__(self, n, using=None):
        self.n = n
        self.using = using
        super().__init__(self.get_connection())

    def get_connection(self):
        """Return the connection for ``self.using`` (the default if ``None``)."""
        if self.using is None:
            return connection
        return connections[self.using]

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        # If the wrapped block raised, its query count is meaningless: let the
        # exception propagate without an extra, misleading warning.
        if exc_type is not None:
            return
        if len(self) != self.n:
            print(
                "The amount of queries should be %s, but %s were captured."
                % (self.n, len(self))
            )


class Benchmark:
    """Time a fixed set of ORM queries on every database/cache combination.

    Each query is run ``n`` times in each of the three ``CONTEXTS``:
    cachalot disabled, cache invalidated before every run, and warm cache
    (zero queries expected). Raw timings are saved to ``data.csv``.
    Per-database and per-cache summaries are written as ``*_results.rst``
    and plotted as SVG files in ``RESULTS_PATH``.
    """

    # Number of timed runs per (query, context, database, cache) combination.
    n = 20

    def __init__(self):
        # One dict per timed run; turned into a DataFrame by run().
        self.data = []

    def bench_once(self, context, num_queries, invalidate_before=False):
        """Run ``self.query_function`` ``n`` times and record each duration.

        :param context: label stored with each sample (one of ``CONTEXTS``).
        :param num_queries: SQL queries each run is expected to make; a
            mismatch is reported by ``AssertNumQueries``.
        :param invalidate_before: invalidate cachalot's cache for the current
            database before every run (the "cold cache" scenario).
        """
        for _ in range(self.n):
            if invalidate_before:
                invalidate(db_alias=self.db_alias)
            with AssertNumQueries(num_queries, using=self.db_alias):
                # perf_counter is monotonic and high-resolution, unlike time(),
                # which can jump with system clock adjustments.
                start = perf_counter()
                self.query_function(self.db_alias)
                end = perf_counter()
            self.data.append(
                {
                    "query": self.query_name,
                    "time": end - start,
                    "context": context,
                    "db": self.db_vendor,
                    "cache": self.cache_name,
                }
            )

    def benchmark(self, query_str, to_list=True, num_queries=1):
        """Benchmark ``Test.objects.using(alias)<query_str>`` in all contexts.

        :param query_str: queryset suffix such as ``".count()"`` or
            ``"[:10]"``; also used as the query label in the results.
        :param to_list: wrap the expression in ``list()`` to force evaluation
            of lazy querysets.
        :param num_queries: SQL queries expected when the cache is not hit.
        """
        # Clears the cache before a single benchmark to ensure the same
        # conditions across single benchmarks.
        caches[settings.CACHALOT_CACHE].clear()

        self.query_name = query_str
        expression = "Test.objects.using(using)" + query_str
        if to_list:
            expression = "list(%s)" % expression
        # eval is acceptable here: query strings are hard-coded in
        # execute_benchmark and never come from external input.
        self.query_function = eval("lambda using: " + expression)

        with override_settings(CACHALOT_ENABLED=False):
            self.bench_once(CONTEXTS[0], num_queries)

        self.bench_once(CONTEXTS[1], num_queries, invalidate_before=True)

        # Hot cache: every result is expected to come from cachalot.
        self.bench_once(CONTEXTS[2], 0)

    def execute_benchmark(self):
        """Run every benchmarked query against the current database/cache."""
        self.benchmark(".count()", to_list=False)
        self.benchmark(".first()", to_list=False)
        self.benchmark("[:10]")
        self.benchmark("[5000:5010]")
        self.benchmark(".filter(name__icontains='e')[0:10]")
        self.benchmark(".filter(name__icontains='e')[5000:5010]")
        self.benchmark(".order_by('owner')[0:10]")
        self.benchmark(".order_by('owner')[5000:5010]")
        self.benchmark(".select_related('owner')[0:10]")
        self.benchmark(".select_related('owner')[5000:5010]")
        # Main query + owners + owners' groups.
        self.benchmark(".prefetch_related('owner__groups')[0:10]", num_queries=3)
        self.benchmark(".prefetch_related('owner__groups')[5000:5010]", num_queries=3)

    def run(self):
        """Benchmark all database/cache pairs, then save data and reports."""
        for db_alias in settings.DATABASES:
            self.db_alias = db_alias
            self.db_vendor = connections[self.db_alias].vendor
            print("Benchmarking %s…" % self.db_vendor)
            for cache_alias in settings.CACHES:
                cache = caches[cache_alias]
                # e.g. "RedisCache" -> "redis": drop the "Cache" suffix.
                self.cache_name = cache.__class__.__name__[:-5].lower()
                with override_settings(CACHALOT_CACHE=cache_alias):
                    self.execute_benchmark()

        self.df = pd.DataFrame.from_records(self.data)
        # makedirs also creates missing parents and tolerates an existing dir.
        os.makedirs(RESULTS_PATH, exist_ok=True)
        self.df.to_csv(os.path.join(RESULTS_PATH, "data.csv"))

        # Shared x-axis range (1% headroom) so every plot is comparable.
        self.xlim = (0, self.df["time"].max() * 1.01)
        self.output("db")
        self.output("cache")

    def output(self, param):
        """Summarize, report and plot timings grouped by *param*.

        :param param: ``"db"`` or ``"cache"``, the column being compared.
        """
        # Per-query detail: rows are CONTEXTS, columns are (param value, query).
        gp = self.df.groupby(["context", "query", param])["time"]
        self.means = gp.mean().unstack().unstack().reindex(CONTEXTS)
        los = self.means - gp.min().unstack().unstack().reindex(CONTEXTS)
        ups = gp.max().unstack().unstack().reindex(CONTEXTS) - self.means
        # Asymmetric error bars, [distance to min, distance to max], per context.
        self.errors = {
            key: {
                subkey: [
                    [los[key][subkey][context] for context in self.means.index],
                    [ups[key][subkey][context] for context in self.means.index],
                ]
                for subkey in self.means.columns.levels[1]
            }
            for key in self.means.columns.levels[0]
        }
        self.get_perfs(param)
        self.plot_detail(param)

        # Overall: rows are CONTEXTS, columns are param values (all queries).
        gp = self.df.groupby(["context", param])["time"]
        self.means = gp.mean().unstack().reindex(CONTEXTS)
        los = self.means - gp.min().unstack().reindex(CONTEXTS)
        ups = gp.max().unstack().reindex(CONTEXTS) - self.means
        self.errors = [
            [
                [los[key][context] for context in self.means.index],
                [ups[key][context] for context in self.means.index],
            ]
            for key in self.means
        ]
        self.plot_general(param)

    def get_perfs(self, param):
        """Print and write ``<param>_results.rst`` with cold/hot cache ratios.

        For each value of *param*, mean times are averaged over all queries.
        "slower" is cold cache / control; "faster" is control / hot cache.
        """
        # utf-8 explicitly: the report contains "×", which locale-default
        # encodings may not support.
        path = os.path.join(RESULTS_PATH, param + "_results.rst")
        with io.open(path, "w", encoding="utf-8") as f:
            for v in self.means.columns.levels[0]:
                g = self.means[v].mean(axis=1)
                perf = "%s is %.1f× slower then %.1f× faster" % (
                    v.ljust(10),
                    g[CONTEXTS[1]] / g[CONTEXTS[0]],
                    g[CONTEXTS[0]] / g[CONTEXTS[2]],
                )
                print(perf)
                f.write("- %s\n" % perf)

    def plot_detail(self, param):
        """Save ``<param>_<value>.svg``: one bar subplot per query (6×2 grid)."""
        for v in self.means.columns.levels[0]:
            # DataFrame.plot creates its own figure, so no plt.figure() here.
            axes = self.means[v].plot(
                kind="barh",
                xerr=self.errors[v],
                xlim=self.xlim,
                figsize=(15, 15),
                subplots=True,
                layout=(6, 2),
                sharey=True,
                legend=False,
            )
            # y axes are shared, so inverting one inverts them all.
            plt.gca().invert_yaxis()
            for row in axes:
                for ax in row:
                    ax.xaxis.grid(True)
                    ax.set_ylabel("")
                    ax.set_xlabel("Time (s)")
            plt.savefig(os.path.join(RESULTS_PATH, "%s_%s.svg" % (param, v)))
            # Release the figure; pyplot otherwise keeps every one alive.
            plt.close()

    def plot_general(self, param):
        """Save ``<param>.svg``: mean time per context for each *param* value."""
        ax = self.means.plot(kind="barh", xerr=self.errors, xlim=self.xlim)
        ax.invert_yaxis()
        ax.xaxis.grid(True)
        ax.set_ylabel("")
        ax.set_xlabel("Time (s)")
        plt.savefig(os.path.join(RESULTS_PATH, "%s.svg" % param))
        plt.close()


def create_data(using):
    """Fill the database aliased *using* with random benchmark fixtures.

    Creates 50 users ("user0".."user49") and 10 groups ("test0".."test9").
    Each user is added to two randomly chosen groups (possibly the same one
    twice). Then 10,000 ``Test`` rows ("test0".."test9999") are created, each
    owned by a random user.
    """
    new_users = [User(username="user%d" % idx) for idx in range(50)]
    User.objects.using(using).bulk_create(new_users)

    new_groups = [Group(name="test%d" % idx) for idx in range(10)]
    Group.objects.using(using).bulk_create(new_groups)

    # Re-read the rows from the database rather than reusing the in-memory
    # objects; presumably because bulk_create does not set primary keys on
    # every backend — confirm.
    all_groups = list(Group.objects.using(using))
    for user in User.objects.using(using):
        user.groups.add(choice(all_groups), choice(all_groups))

    all_users = list(User.objects.using(using))
    new_tests = [
        Test(name="test%d" % idx, owner=choice(all_users))
        for idx in range(10000)
    ]
    Test.objects.using(using).bulk_create(new_tests)


if __name__ == "__main__":
    # makedirs also creates missing parents ("benchmark/docs") and tolerates
    # a directory left by an earlier run on the same day.
    os.makedirs(RESULTS_PATH, exist_ok=True)

    write_conditions()

    # Original database NAME per alias, recorded only once its test database
    # exists, so that cleanup touches exactly the databases it created.
    old_db_names = {}
    try:
        for alias in connections:
            conn = connections[alias]
            # create_test_db() rewrites settings_dict["NAME"], so read it first.
            old_name = conn.settings_dict["NAME"]
            conn.creation.create_test_db(autoclobber=True)
            old_db_names[alias] = old_name

            print("Populating %s…" % conn.vendor)
            create_data(alias)

        Benchmark().run()
    finally:
        # Drop the test databases even if populating or benchmarking failed,
        # passing each connection's original NAME back.
        for alias, old_name in old_db_names.items():
            connections[alias].creation.destroy_test_db(old_name)
